{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import json\n",
    "from keras.models import Model\n",
    "from keras.layers import Input\n",
    "from keras.layers.convolutional import Conv2D\n",
    "from keras.layers.pooling import MaxPooling2D, AveragePooling2D\n",
    "from keras.layers.normalization import BatchNormalization\n",
    "from keras import backend as K\n",
    "from collections import OrderedDict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def format_decimal(arr, places=6):\n",
    "    \"\"\"Round every number in `arr` to `places` decimal places.\n",
    "\n",
    "    Each value is multiplied by 10**places, passed through the built-in\n",
    "    `round`, and divided by 10**places again.\n",
    "\n",
    "    Args:\n",
    "        arr: iterable of numbers.\n",
    "        places: number of decimal places to keep (default 6).\n",
    "\n",
    "    Returns:\n",
    "        A new list holding the rounded values in input order.\n",
    "    \"\"\"\n",
    "    scale = 10 ** places\n",
    "    rounded = []\n",
    "    for value in arr:\n",
    "        rounded.append(round(value * scale) / scale)\n",
    "    return rounded"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Collects the generated fixtures, keyed by fixture name, for JSON serialization.\n",
    "DATA = OrderedDict()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### pipeline 9"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "random_seed = 1009\n",
    "data_in_shape = (8, 8, 2)\n",
    "\n",
    "\n",
    "def _seeded_uniform(shape, seed):\n",
    "    \"\"\"Seed numpy's global RNG with `seed`, then return 2 * random(shape) - 1.\"\"\"\n",
    "    np.random.seed(seed)\n",
    "    return 2 * np.random.random(shape) - 1\n",
    "\n",
    "\n",
    "# Layers are applied to the input in list order.\n",
    "layers = [\n",
    "    Conv2D(\n",
    "        4,\n",
    "        (3,3),\n",
    "        activation='relu',\n",
    "        padding='valid',\n",
    "        strides=(1,1),\n",
    "        data_format='channels_last',\n",
    "        use_bias=True,\n",
    "    ),\n",
    "    AveragePooling2D(\n",
    "        pool_size=(2,2),\n",
    "        strides=None,\n",
    "        padding='valid',\n",
    "        data_format='channels_last',\n",
    "    ),\n",
    "]\n",
    "\n",
    "# Chain every layer onto the input tensor; the last result is the model output.\n",
    "input_layer = Input(shape=data_in_shape)\n",
    "x = input_layer\n",
    "for layer in layers:\n",
    "    x = layer(x)\n",
    "output_layer = x\n",
    "model = Model(inputs=input_layer, outputs=output_layer)\n",
    "\n",
    "# Input sample drawn with the base seed so the fixture is reproducible.\n",
    "data_in = _seeded_uniform(data_in_shape, random_seed)\n",
    "\n",
    "# Replace each weight array with random values, seeding with\n",
    "# random_seed + position so every array is reproducible on its own.\n",
    "weights = [\n",
    "    _seeded_uniform(current.shape, random_seed + position)\n",
    "    for position, current in enumerate(model.get_weights())\n",
    "]\n",
    "model.set_weights(weights)\n",
    "\n",
    "# Run the single sample as a batch of one and keep only its prediction.\n",
    "result = model.predict(np.array([data_in]))\n",
    "prediction = result[0]\n",
    "data_out_shape = prediction.shape\n",
    "data_in_formatted = format_decimal(data_in.ravel().tolist())\n",
    "data_out_formatted = format_decimal(prediction.ravel().tolist())\n",
    "\n",
    "weights_formatted = [\n",
    "    {'data': format_decimal(w.ravel().tolist()), 'shape': w.shape}\n",
    "    for w in weights\n",
    "]\n",
    "DATA['pipeline_09'] = {\n",
    "    'input': {'data': data_in_formatted, 'shape': data_in_shape},\n",
    "    'weights': weights_formatted,\n",
    "    'expected': {'data': data_out_formatted, 'shape': data_out_shape}\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### export for Keras.js tests"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Destination of the exported fixtures; a relative path, resolved against\n",
    "# the kernel's current working directory.\n",
    "filename = '../../test/data/pipeline/09.json'\n",
    "\n",
    "# exist_ok=True creates the directory only when it is missing, without a\n",
    "# separate exists() check that could race with another process creating it.\n",
    "os.makedirs(os.path.dirname(filename), exist_ok=True)\n",
    "\n",
    "# Write every collected fixture in DATA as a single JSON document.\n",
    "with open(filename, 'w') as f:\n",
    "    json.dump(DATA, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\"pipeline_09\": {\"input\": {\"data\": [0.229761, -0.670851, -0.398017, -0.453102, 0.590253, 0.096484, -0.035974, 0.69389, -0.260438, 0.250677, -0.193674, 0.385879, -0.991763, -0.558061, -0.750891, 0.265556, 0.657619, -0.801723, -0.910484, 0.975926, 0.288351, -0.364498, -0.126447, 0.616925, 0.303193, -0.820024, 0.429665, 0.070106, -0.34556, -0.459221, -0.798301, 0.998903, 0.359268, -0.957685, -0.915632, 0.691774, 0.25984, 0.879641, 0.179776, -0.423543, -0.867718, -0.560873, -0.382501, -0.365361, -0.723601, -0.033603, 0.789498, 0.455418, 0.827951, 0.872766, -0.595781, 0.655746, -0.100867, -0.407415, -0.5028, -0.07739, -0.432834, 0.371304, -0.410596, -0.92421, 0.688391, -0.862805, 0.206217, -0.602252, -0.132926, -0.676393, -0.280987, 0.247491, -0.027147, -0.016659, -0.23147, -0.653909, 0.876245, -0.853047, 0.6974, 0.939353, 0.068803, -0.786326, -0.10847, -0.038066, 0.594608, -0.631675, -0.052451, 0.362388, 0.573359, -0.461688, -0.799513, 0.883652, -0.681072, 0.708611, -0.103832, 0.161348, -0.938899, -0.604966, 0.588076, -0.410424, 0.215717, -0.211664, 0.012579, -0.756258, 0.346215, -0.858999, 0.517538, -0.867138, -0.49847, -0.114061, 0.906122, -0.446741, -0.632029, 0.913293, -0.420459, 0.031947, -0.34244, 0.231598, 0.806704, 0.479821, 0.133654, -0.915619, 0.113133, -0.995506, -0.979647, 0.571933, 0.161284, 0.035078, 0.210916, 0.504783, 0.667269, 0.356783], \"shape\": [8, 8, 2]}, \"weights\": [{\"data\": [0.229761, -0.670851, -0.398017, -0.453102, 0.590253, 0.096484, -0.035974, 0.69389, -0.260438, 0.250677, -0.193674, 0.385879, -0.991763, -0.558061, -0.750891, 0.265556, 0.657619, -0.801723, -0.910484, 0.975926, 0.288351, -0.364498, -0.126447, 0.616925, 0.303193, -0.820024, 0.429665, 0.070106, -0.34556, -0.459221, -0.798301, 0.998903, 0.359268, -0.957685, -0.915632, 0.691774, 0.25984, 0.879641, 0.179776, -0.423543, -0.867718, -0.560873, -0.382501, -0.365361, -0.723601, -0.033603, 0.789498, 0.455418, 0.827951, 0.872766, -0.595781, 0.655746, -0.100867, 
-0.407415, -0.5028, -0.07739, -0.432834, 0.371304, -0.410596, -0.92421, 0.688391, -0.862805, 0.206217, -0.602252, -0.132926, -0.676393, -0.280987, 0.247491, -0.027147, -0.016659, -0.23147, -0.653909], \"shape\": [3, 3, 2, 4]}, {\"data\": [-0.211487, -0.648815, -0.854588, -0.616238], \"shape\": [4]}], \"expected\": {\"data\": [0.509739, 0.237036, 0.293895, 0.282399, 0.423908, 0.008688, 0.304865, 0.666234, 0.022924, 0.541168, 1.17086, 0.0, 0.203349, 0.417828, 0.0, 0.398297, 0.175827, 0.763362, 0.328252, 0.0, 0.780266, 0.128615, 0.007518, 0.537463, 0.079002, 0.150453, 0.0, 0.507228, 0.436817, 0.946573, 0.496548, 0.50641, 1.045505, 0.082524, 0.213519, 0.0], \"shape\": [3, 3, 4]}}}\n"
     ]
    }
   ],
   "source": [
    "# Echo the serialized fixtures so the JSON content is visible in the notebook output.\n",
    "print(json.dumps(DATA))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
